#!/usr/bin/env python3
"""
Complete Data Analysis Pipeline for Subtitle Support Study
==========================================================

This script provides a fully reproducible analysis pipeline from raw physiological
data files to final statistical results.

Study Design:
- Within-subject repeated measures (n=15)
- Two conditions: Block1 (with subtitles), Block2 (without subtitles)
- Five physiological channels: EEG (AF3 theta), GSR, BVP, Heart Rate, Respiration Rate

Author: [Research Team]
Date: November 2025
Python Version: 3.11+
Required Packages: pandas, numpy, scipy, matplotlib
"""

import os
import glob
import pandas as pd
import numpy as np
from scipy import signal, stats
from scipy.stats import trim_mean
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# SECTION 1: DATA LOADING AND PREPROCESSING
# =============================================================================

def load_raw_data(filepath):
    """
    Read one tab-delimited recording into a DataFrame with normalised headers.

    Each column label is stripped of surrounding whitespace, lower-cased and
    has its spaces replaced by underscores (e.g. ``" GSR Value "`` becomes
    ``"gsr_value"``). Any failure while reading or renaming is reported on
    stdout and signalled by returning ``None`` instead of raising.

    Parameters:
    -----------
    filepath : str
        Path to the raw data file

    Returns:
    --------
    df : pandas.DataFrame or None
        Raw data with one column per physiological channel, or None when the
        file could not be loaded
    """
    def _tidy(label):
        # Canonical header form used by the column look-ups downstream.
        return label.strip().lower().replace(' ', '_')

    try:
        table = pd.read_csv(filepath, sep='\t', encoding='utf-8')
        table.columns = list(map(_tidy, table.columns))
    except Exception as err:
        print(f"Error loading {filepath}: {err}")
        return None

    return table


def extract_participant_info(filename):
    """
    Extract participant ID and condition from filename.

    Expected format: P[ID]_Block[1/2].txt
    Example: P01_Block1.txt -> participant=1, condition='with_subtitles'

    The participant ID is taken from the first ``P<digits>`` token in the base
    name. Names that contain a 'P' without a following number (for example
    ``README_PILOT.txt``) yield ``participant_id=None`` instead of raising, so
    a stray text file in the data directory cannot abort a batch run.

    Parameters:
    -----------
    filename : str
        Name of (or path to) the data file; only the base name is inspected

    Returns:
    --------
    tuple : (participant_id, condition)
        participant_id is an int or None; condition is 'with_subtitles'
        (Block1), 'without_subtitles' (Block2) or None
    """
    import re  # local import: the module header does not import re

    basename = os.path.basename(filename)

    # Extract participant ID from the first "P<digits>" token
    id_match = re.search(r'P(\d+)', basename)
    participant_id = int(id_match.group(1)) if id_match else None

    # Extract condition (Block1 is checked first, as before)
    if 'Block1' in basename:
        condition = 'with_subtitles'
    elif 'Block2' in basename:
        condition = 'without_subtitles'
    else:
        condition = None

    return participant_id, condition


# =============================================================================
# SECTION 2: SIGNAL PROCESSING FUNCTIONS
# =============================================================================

def process_eeg_theta(eeg_signal, sampling_rate=256, freq_band=(4, 7)):
    """
    Extract theta band (4-7 Hz) power from EEG signal using Welch's method.

    Parameters:
    -----------
    eeg_signal : array-like
        Raw EEG signal from AF3 electrode (list, ndarray or Series)
    sampling_rate : int
        Sampling rate in Hz (default: 256)
    freq_band : tuple
        Inclusive (low, high) frequency band for theta (default: 4-7 Hz)

    Returns:
    --------
    theta_power : float
        log10 of the mean power spectral density inside the band, or NaN when
        fewer than two seconds of valid samples are available or no frequency
        bin falls inside the band
    """
    # Coerce first so plain lists and integer columns are accepted as well
    eeg = np.asarray(eeg_signal, dtype=float)

    # Remove NaN values. Note this concatenates the remaining samples, so gaps
    # in the recording are treated as if they were contiguous.
    eeg_clean = eeg[~np.isnan(eeg)]

    # Welch segments are two seconds long; require at least one full segment
    segment_length = int(sampling_rate * 2)
    if len(eeg_clean) < segment_length:
        return np.nan

    # Welch's method for power spectral density
    freqs, psd = signal.welch(eeg_clean, fs=sampling_rate, nperseg=segment_length)

    # Extract theta band power
    low, high = freq_band
    theta_idx = (freqs >= low) & (freqs <= high)
    if not theta_idx.any():
        return np.nan
    theta_power = np.mean(psd[theta_idx])

    # Log transform for normalization (the offset avoids log10(0))
    return np.log10(theta_power + 1e-10)


def process_gsr(gsr_signal, sampling_rate=32, cutoff_freq=0.5):
    """
    Extract tonic skin conductance level (SCL) from GSR signal.

    Parameters:
    -----------
    gsr_signal : array-like
        Raw GSR/EDA signal (list, ndarray or Series)
    sampling_rate : int
        Sampling rate in Hz (default: 32)
    cutoff_freq : float
        Low-pass filter cutoff frequency (default: 0.5 Hz)

    Returns:
    --------
    scl_mean : float
        Mean of the low-pass filtered signal, in the units of the input
        (the pipeline assumes µS), or NaN when there are too few valid
        samples to filter
    """
    # Coerce first so plain lists and integer columns are accepted as well
    gsr = np.asarray(gsr_signal, dtype=float)

    # Remove NaN values
    gsr_clean = gsr[~np.isnan(gsr)]

    if len(gsr_clean) < sampling_rate:
        return np.nan

    # Low-pass filter to isolate tonic component
    nyquist = sampling_rate / 2
    b, a = signal.butter(3, cutoff_freq / nyquist, btype='low')

    # filtfilt pads each end with 3 * max(len(a), len(b)) samples by default
    # and raises if the signal is not longer than that.
    if len(gsr_clean) <= 3 * max(len(a), len(b)):
        return np.nan

    gsr_tonic = signal.filtfilt(b, a, gsr_clean)

    # Mean tonic level
    return np.mean(gsr_tonic)


def process_bvp(bvp_signal, sampling_rate=128):
    """
    Extract BVP amplitude from blood volume pulse signal.

    Parameters:
    -----------
    bvp_signal : array-like
        Raw BVP signal (list, ndarray or Series)
    sampling_rate : int
        Sampling rate in Hz (default: 128)

    Returns:
    --------
    bvp_amplitude : float
        Mean BVP amplitude (peak-to-trough), or NaN when fewer than two
        seconds of valid samples or fewer than two pulse peaks are available
    """
    # Coerce first so plain lists and integer columns are accepted as well
    bvp = np.asarray(bvp_signal, dtype=float)

    # Remove NaN values
    bvp_clean = bvp[~np.isnan(bvp)]

    if len(bvp_clean) < sampling_rate * 2:
        return np.nan

    # Low-pass filter at 10 Hz
    nyquist = sampling_rate / 2
    b, a = signal.butter(3, 10 / nyquist, btype='low')
    bvp_filtered = signal.filtfilt(b, a, bvp_clean)

    # Detect peaks (systolic). Minimum 0.5 sec between peaks (max 120 BPM);
    # find_peaks requires a distance of at least one sample.
    min_distance = max(1, int(sampling_rate * 0.5))
    peaks, _ = signal.find_peaks(bvp_filtered, distance=min_distance)

    if len(peaks) < 2:
        return np.nan

    # Amplitude of each pulse: the peak value minus the lowest value found
    # between the previous peak and this one.
    amplitudes = [
        bvp_filtered[current] - np.min(bvp_filtered[previous:current])
        for previous, current in zip(peaks[:-1], peaks[1:])
    ]

    return np.mean(amplitudes)


def process_heart_rate(bvp_signal, sampling_rate=128):
    """
    Calculate heart rate from BVP signal using inter-beat intervals.

    Parameters:
    -----------
    bvp_signal : array-like
        Raw BVP signal (list, ndarray or Series)
    sampling_rate : int
        Sampling rate in Hz (default: 128)

    Returns:
    --------
    heart_rate : float
        Mean heart rate in beats per minute (BPM) over all beat-to-beat
        intervals within 40-180 BPM, or NaN when fewer than two seconds of
        valid samples, fewer than two peaks, or no plausible interval exist
    """
    # Coerce first so plain lists and integer columns are accepted as well
    bvp = np.asarray(bvp_signal, dtype=float)

    # Remove NaN values
    bvp_clean = bvp[~np.isnan(bvp)]

    if len(bvp_clean) < sampling_rate * 2:
        return np.nan

    # Low-pass filter at 10 Hz
    nyquist = sampling_rate / 2
    b, a = signal.butter(3, 10 / nyquist, btype='low')
    bvp_filtered = signal.filtfilt(b, a, bvp_clean)

    # Detect peaks; find_peaks requires a distance of at least one sample.
    # NOTE(review): a 0.5 s minimum spacing caps detectable rates at 120 BPM,
    # although the plausibility filter below admits values up to 180 BPM.
    min_distance = max(1, int(sampling_rate * 0.5))
    peaks, _ = signal.find_peaks(bvp_filtered, distance=min_distance)

    if len(peaks) < 2:
        return np.nan

    # Calculate inter-beat intervals (IBI)
    ibis = np.diff(peaks) / sampling_rate * 1000  # Convert to milliseconds

    # Convert to heart rate (BPM)
    heart_rates = 60000 / ibis

    # Filter physiologically plausible values (40-180 BPM)
    heart_rates = heart_rates[(heart_rates >= 40) & (heart_rates <= 180)]

    return np.mean(heart_rates) if heart_rates.size > 0 else np.nan


def process_respiration(resp_signal, sampling_rate=32):
    """
    Calculate respiration rate from respiration belt signal.

    Parameters:
    -----------
    resp_signal : array-like
        Raw respiration signal (list, ndarray or Series)
    sampling_rate : int
        Sampling rate in Hz (default: 32)

    Returns:
    --------
    respiration_rate : float
        Mean respiration rate in breaths per minute over all breath-to-breath
        intervals within 8-40 breaths/min, or NaN when fewer than ten seconds
        of valid samples, fewer than two peaks, or no plausible interval exist
    """
    # Coerce first so plain lists and integer columns are accepted as well
    resp = np.asarray(resp_signal, dtype=float)

    # Remove NaN values
    resp_clean = resp[~np.isnan(resp)]

    if len(resp_clean) < sampling_rate * 10:
        return np.nan

    # Band-pass filter (0.05-1 Hz for respiration)
    nyquist = sampling_rate / 2
    b, a = signal.butter(3, [0.05 / nyquist, 1.0 / nyquist], btype='band')
    resp_filtered = signal.filtfilt(b, a, resp_clean)

    # Detect peaks (inspiratory). Minimum 1.5 sec between breaths;
    # find_peaks requires a distance of at least one sample.
    min_distance = max(1, int(sampling_rate * 1.5))
    peaks, _ = signal.find_peaks(resp_filtered, distance=min_distance)

    if len(peaks) < 2:
        return np.nan

    # Calculate inter-breath intervals
    ibis = np.diff(peaks) / sampling_rate  # In seconds

    # Convert to breaths per minute
    breath_rates = 60 / ibis

    # Filter physiologically plausible values (8-40 breaths/min)
    breath_rates = breath_rates[(breath_rates >= 8) & (breath_rates <= 40)]

    return np.mean(breath_rates) if breath_rates.size > 0 else np.nan


# =============================================================================
# SECTION 3: FEATURE EXTRACTION PIPELINE
# =============================================================================

def extract_features_from_file(filepath):
    """
    Build the feature record for one recording file.

    The file is loaded, its participant/condition metadata is parsed from the
    file name, and each physiological channel that is present is reduced to a
    single summary value. Channels whose column is absent are recorded as NaN.

    Parameters:
    -----------
    filepath : str
        Path to raw data file

    Returns:
    --------
    features : dict or None
        Keys: participant_id, condition, filename, theta_power, gsr_tonic,
        bvp_amplitude, heart_rate, respiration_rate. None is returned when the
        file cannot be loaded or any channel raises during processing.
    """
    df = load_raw_data(filepath)
    if df is None:
        return None

    participant_id, condition = extract_participant_info(filepath)

    features = dict(
        participant_id=participant_id,
        condition=condition,
        filename=os.path.basename(filepath),
    )

    # Accepted column names per channel (adjust to the actual headers):
    #   EEG: sensor-a / eeg        GSR: sensor-e / gsr / sc
    #   BVP: sensor-g / bvp        Respiration: respiration / resp
    # NOTE(review): every channel is processed with its function's default
    # sampling rate - confirm these match the acquisition settings.
    try:
        header = df.columns

        # EEG theta power: first alias, in priority order, that is present
        eeg_col = next((alias for alias in ('sensor-a', 'eeg') if alias in header), None)
        if eeg_col is None:
            features['theta_power'] = np.nan
        else:
            features['theta_power'] = process_eeg_theta(df[eeg_col].values)

        # GSR: first column, in file order, that matches any alias
        gsr_col = next((label for label in header if label in ('sensor-e', 'gsr', 'sc')), None)
        if gsr_col is None:
            features['gsr_tonic'] = np.nan
        else:
            features['gsr_tonic'] = process_gsr(df[gsr_col].values)

        # BVP feeds both the pulse amplitude and the heart-rate estimate
        bvp_col = next((alias for alias in ('sensor-g', 'bvp') if alias in header), None)
        if bvp_col is None:
            features['bvp_amplitude'] = np.nan
            features['heart_rate'] = np.nan
        else:
            features['bvp_amplitude'] = process_bvp(df[bvp_col].values)
            features['heart_rate'] = process_heart_rate(df[bvp_col].values)

        # Respiration
        resp_col = next((alias for alias in ('respiration', 'resp') if alias in header), None)
        if resp_col is None:
            features['respiration_rate'] = np.nan
        else:
            features['respiration_rate'] = process_respiration(df[resp_col].values)

    except Exception as err:
        print(f"Error processing {filepath}: {err}")
        return None

    return features


def process_all_files(data_directory):
    """
    Process all data files in the specified directory.

    Every ``*.txt`` file directly inside ``data_directory`` is passed to
    ``extract_features_from_file``; files for which it returns None are
    skipped. Files are visited in sorted name order so console output is
    reproducible across platforms.

    Parameters:
    -----------
    data_directory : str
        Path to directory containing all raw data files

    Returns:
    --------
    df_features : pandas.DataFrame
        One row per successfully processed file, sorted by participant_id and
        condition. When no file could be processed, an empty DataFrame with
        the expected feature columns is returned.
    """
    # Columns produced by extract_features_from_file; only needed to give an
    # empty result a usable schema.
    feature_columns = [
        'participant_id', 'condition', 'filename', 'theta_power',
        'gsr_tonic', 'bvp_amplitude', 'heart_rate', 'respiration_rate',
    ]

    # Find all txt files (glob order is arbitrary, so sort it)
    file_pattern = os.path.join(data_directory, "*.txt")
    files = sorted(glob.glob(file_pattern))

    print(f"Found {len(files)} data files")

    # Extract features from each file
    all_features = []
    for filepath in files:
        print(f"Processing: {os.path.basename(filepath)}")
        features = extract_features_from_file(filepath)
        if features is not None:
            all_features.append(features)

    if all_features:
        # Sort by participant and condition
        df_features = pd.DataFrame(all_features).sort_values(
            ['participant_id', 'condition']
        )
    else:
        # A frame built from an empty list has no columns, so sorting or
        # column access would raise KeyError; supply the schema explicitly.
        df_features = pd.DataFrame(columns=feature_columns)

    print(f"\nSuccessfully processed {len(df_features)} files")
    print(f"Participants: {df_features['participant_id'].nunique()}")
    print(f"Conditions: {df_features['condition'].unique()}")

    return df_features


# =============================================================================
# SECTION 4: STATISTICAL ANALYSIS FUNCTIONS
# =============================================================================

def _winsorize_in_place_order(values, g):
    """Clamp the g smallest/largest values to the nearest retained order statistic, keeping element order."""
    ordered = np.sort(values)
    return np.clip(values, ordered[g], ordered[len(values) - g - 1])


def robust_paired_ttest(x, y, trim_proportion=0.2):
    """
    Perform robust paired-samples t-test with trimmed means.

    Implements Yuen's test for dependent groups: the difference between the
    two marginal trimmed means is divided by a standard error built from the
    winsorized variances AND the winsorized covariance of the pairs, so the
    within-subject correlation is taken into account. With
    ``trim_proportion=0`` the result equals the ordinary paired t-test.

    Parameters:
    -----------
    x, y : array-like
        Paired samples (same length); element i of x is paired with element i
        of y
    trim_proportion : float
        Proportion to trim from each end (default: 0.2 = 20%); must satisfy
        0 <= trim_proportion < 0.5

    Returns:
    --------
    results : dict
        t_statistic, df, p_two_tailed, p_one_tailed, cohens_d, x_mean, y_mean,
        x_trimmed_mean, y_trimmed_mean, mean_difference, n, n_trimmed.
        p_one_tailed is the tail probability in the direction of the observed
        effect (half of p_two_tailed). cohens_d uses untrimmed means and the
        average of the two sample variances as standardiser. Test statistics
        are NaN when fewer than two observations remain after trimming.

    Raises:
    -------
    ValueError
        If x and y are not one-dimensional with equal length, or if
        trim_proportion is outside [0, 0.5)
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError("x and y must be one-dimensional arrays of equal length")
    if not 0 <= trim_proportion < 0.5:
        raise ValueError("trim_proportion must be in the interval [0, 0.5)")

    n = len(x)

    # Observations trimmed from each tail, and what remains
    g = int(trim_proportion * n)
    n_trimmed = n - 2 * g
    df = n_trimmed - 1

    # Calculate trimmed means
    x_trimmed_mean = trim_mean(x, trim_proportion)
    y_trimmed_mean = trim_mean(y, trim_proportion)

    # Division by a zero standard error / standard deviation yields inf or
    # NaN rather than an exception; silence the accompanying warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        if df >= 1:
            # Winsorize each variable without reordering so pairs stay aligned
            x_dev = _winsorize_in_place_order(x, g)
            y_dev = _winsorize_in_place_order(y, g)
            x_dev = x_dev - x_dev.mean()
            y_dev = y_dev - y_dev.mean()

            # sum((dx - dy)^2) = SS_x + SS_y - 2 * SP_xy of the winsorized data
            winsorized_ss = np.sum((x_dev - y_dev) ** 2)
            se = np.sqrt(winsorized_ss / (n_trimmed * (n_trimmed - 1)))

            t_stat = (x_trimmed_mean - y_trimmed_mean) / se

            # Survival function is more accurate than 1 - cdf for small p
            p_one_tailed = stats.t.sf(np.abs(t_stat), df)
            p_two_tailed = 2 * p_one_tailed
        else:
            t_stat = p_one_tailed = p_two_tailed = np.nan

        # Effect size (Cohen's d) using full data
        pooled_std = np.sqrt((np.var(x, ddof=1) + np.var(y, ddof=1)) / 2)
        cohens_d = (np.mean(x) - np.mean(y)) / pooled_std

    results = {
        't_statistic': t_stat,
        'df': df,
        'p_two_tailed': p_two_tailed,
        'p_one_tailed': p_one_tailed,
        'cohens_d': cohens_d,
        'x_mean': np.mean(x),
        'y_mean': np.mean(y),
        'x_trimmed_mean': x_trimmed_mean,
        'y_trimmed_mean': y_trimmed_mean,
        'mean_difference': np.mean(x) - np.mean(y),
        'n': n,
        'n_trimmed': n_trimmed
    }

    return results


def calculate_weighted_composite_index(df_features, measures, weights=None):
    """
    Calculate weighted composite stress index from multiple physiological measures.

    Each measure is z-scored across all rows of ``df_features`` (sample
    standard deviation), and the composite is the weighted sum of the z-scores
    of the measures named in ``weights``. The input frame is not modified.

    Parameters:
    -----------
    df_features : pandas.DataFrame
        DataFrame containing all features
    measures : list
        List of column names to standardise; a ``<measure>_z`` column is added
        for each
    weights : dict, optional
        Mapping of measure name to non-normalised weight (default: equal
        weights over ``measures``). Weights are rescaled to sum to 1. An empty
        dict yields a composite of 0 for every row.

    Returns:
    --------
    (df_composite, weights) : tuple
        df_composite is a copy of df_features with the ``<measure>_z`` columns
        and a 'composite_index' column added; weights is the dict of
        normalised weights actually applied

    Raises:
    -------
    ValueError
        If weights are supplied but their sum is zero or not finite, which
        would make normalisation meaningless
    """
    df_composite = df_features.copy()

    # Standardize each measure (z-score)
    for measure in measures:
        column = df_composite[measure]
        df_composite[f'{measure}_z'] = (column - column.mean()) / column.std()

    # Apply weights
    if weights is None:
        weights = {m: 1.0 / len(measures) for m in measures}

    # Normalize weights to sum to 1
    total_weight = sum(weights.values())
    if weights and (total_weight == 0 or not np.isfinite(total_weight)):
        raise ValueError(
            "weights must sum to a finite, non-zero value to be normalised"
        )
    weights = {k: v / total_weight for k, v in weights.items()}

    # Calculate composite index
    df_composite['composite_index'] = 0
    for measure, weight in weights.items():
        df_composite['composite_index'] += df_composite[f'{measure}_z'] * weight

    return df_composite, weights


# =============================================================================
# SECTION 5: MAIN ANALYSIS PIPELINE
# =============================================================================

def _split_paired_conditions(df):
    """Return (with_subtitles, without_subtitles) rows limited to participants present in both conditions, each sorted by participant_id so rows pair up by position."""
    with_rows = df[df['condition'] == 'with_subtitles']
    without_rows = df[df['condition'] == 'without_subtitles']

    # Ensure same participants in both conditions
    common_participants = set(with_rows['participant_id']) & set(without_rows['participant_id'])
    with_rows = with_rows[with_rows['participant_id'].isin(common_participants)]
    without_rows = without_rows[without_rows['participant_id'].isin(common_participants)]

    # Sort by participant ID
    with_rows = with_rows.sort_values('participant_id')
    without_rows = without_rows.sort_values('participant_id')

    # Positional pairing is only valid if both frames list the same IDs
    if with_rows['participant_id'].tolist() != without_rows['participant_id'].tolist():
        raise ValueError(
            "Conditions cannot be paired: a participant has a different "
            "number of recordings in each condition"
        )

    return with_rows, without_rows


def main_analysis(data_directory, output_directory):
    """
    Execute complete analysis pipeline.

    Steps: (1) extract features from every raw file and save them,
    (2) run a robust paired t-test per measure, (3) build a composite index
    weighted by the absolute effect sizes from step 2 and test it with an
    ordinary paired t-test, (4) print and save a text summary.

    Files written to ``output_directory``: extracted_features.csv,
    robust_ttest_results.csv, composite_index_results.csv and
    analysis_summary.txt.

    Parameters:
    -----------
    data_directory : str
        Path to directory containing raw data files
    output_directory : str
        Path to directory for saving results

    Raises:
    -------
    ValueError
        If a participant has an unequal number of recordings in the two
        conditions, so observations cannot be paired
    """
    print("="*70)
    print("COMPLETE ANALYSIS PIPELINE FOR SUBTITLE SUPPORT STUDY")
    print("="*70)

    # Create output directory if it doesn't exist
    os.makedirs(output_directory, exist_ok=True)

    # -------------------------------------------------------------------------
    # STEP 1: Feature Extraction
    # -------------------------------------------------------------------------
    print("\n[STEP 1] Extracting features from raw data files...")
    df_features = process_all_files(data_directory)

    if df_features.empty:
        # Nothing to analyse; every later step would only produce NaN output
        print("No usable data files were found; analysis aborted.")
        return

    # Save extracted features
    features_file = os.path.join(output_directory, 'extracted_features.csv')
    df_features.to_csv(features_file, index=False)
    print(f"Saved extracted features to: {features_file}")

    # -------------------------------------------------------------------------
    # STEP 2: Primary Analysis - Robust t-tests
    # -------------------------------------------------------------------------
    print("\n[STEP 2] Performing robust paired t-tests...")

    measures_to_analyze = {
        'bvp_amplitude': 'BVP Amplitude',
        'heart_rate': 'Heart Rate',
        'gsr_tonic': 'GSR (Tonic)',
        'respiration_rate': 'Respiration Rate'
    }
    # Reverse lookup: display name -> column name
    column_for_name = {name: col for col, name in measures_to_analyze.items()}

    # Separate conditions, aligned on participant
    with_subtitles, without_subtitles = _split_paired_conditions(df_features)

    # Perform robust t-tests
    robust_results = []
    for measure_col, measure_name in measures_to_analyze.items():
        x = with_subtitles[measure_col].to_numpy(dtype=float)
        y = without_subtitles[measure_col].to_numpy(dtype=float)

        # Remove pairs with missing data
        valid_mask = ~(np.isnan(x) | np.isnan(y))
        x_valid = x[valid_mask]
        y_valid = y[valid_mask]

        if len(x_valid) < 5:
            print(f"Warning: Insufficient data for {measure_name}")
            continue

        results = robust_paired_ttest(x_valid, y_valid, trim_proportion=0.2)
        results['measure'] = measure_name
        robust_results.append(results)

        print(f"\n{measure_name}:")
        print(f"  With subtitles: M = {results['x_mean']:.2f}")
        print(f"  Without subtitles: M = {results['y_mean']:.2f}")
        print(f"  t({results['df']}) = {results['t_statistic']:.3f}")
        print(f"  p (one-tailed) = {results['p_one_tailed']:.4f}")
        print(f"  Cohen's d = {results['cohens_d']:.3f}")

    # Save robust t-test results
    df_robust = pd.DataFrame(robust_results)
    robust_file = os.path.join(output_directory, 'robust_ttest_results.csv')
    df_robust.to_csv(robust_file, index=False)
    print(f"\nSaved robust t-test results to: {robust_file}")

    # -------------------------------------------------------------------------
    # STEP 3: Supplementary Analysis - Weighted Composite Index
    # -------------------------------------------------------------------------
    print("\n[STEP 3] Calculating weighted composite stress index...")

    # Calculate weights based on effect sizes.
    # NOTE(review): the weights come from the same data the composite is then
    # tested on, so this test is not independent of the step 2 results.
    weights = {
        column_for_name[result['measure']]: abs(result['cohens_d'])
        for result in robust_results
    }

    print("\nWeights based on effect sizes:")
    for measure_col, weight in weights.items():
        print(f"  {measures_to_analyze[measure_col]}: {weight:.4f}")

    # Calculate composite index
    df_composite, normalized_weights = calculate_weighted_composite_index(
        df_features,
        list(measures_to_analyze.keys()),
        weights
    )

    print("\nNormalized weights:")
    for measure_col, weight in normalized_weights.items():
        print(f"  {measures_to_analyze[measure_col]}: {weight:.4f}")

    # Perform t-test on composite index, using the same participant alignment
    # as step 2 so unpaired recordings cannot shift the pairing.
    composite_with_rows, composite_without_rows = _split_paired_conditions(df_composite)
    composite_with = composite_with_rows['composite_index'].to_numpy(dtype=float)
    composite_without = composite_without_rows['composite_index'].to_numpy(dtype=float)

    # Remove missing data
    valid_mask = ~(np.isnan(composite_with) | np.isnan(composite_without))
    composite_with = composite_with[valid_mask]
    composite_without = composite_without[valid_mask]

    if len(composite_with) < 2:
        print("Warning: Insufficient paired data for composite index")
        t_stat = p_one = cohens_d = np.nan
    else:
        t_stat, p_two = stats.ttest_rel(composite_with, composite_without)
        # Directional test: small p only when with_subtitles < without_subtitles
        p_one = p_two / 2 if t_stat < 0 else 1 - (p_two / 2)

        pooled_std = np.sqrt((np.var(composite_with, ddof=1) + np.var(composite_without, ddof=1)) / 2)
        cohens_d = (np.mean(composite_with) - np.mean(composite_without)) / pooled_std

    print("\nWeighted Composite Stress Index:")
    print(f"  With subtitles: M = {np.mean(composite_with):.3f}, SD = {np.std(composite_with, ddof=1):.3f}")
    print(f"  Without subtitles: M = {np.mean(composite_without):.3f}, SD = {np.std(composite_without, ddof=1):.3f}")
    print(f"  t({len(composite_with)-1}) = {t_stat:.3f}")
    print(f"  p (one-tailed) = {p_one:.4f}")
    print(f"  Cohen's d = {cohens_d:.3f}")

    # Save composite results
    composite_file = os.path.join(output_directory, 'composite_index_results.csv')
    df_composite.to_csv(composite_file, index=False)
    print(f"\nSaved composite index results to: {composite_file}")

    # -------------------------------------------------------------------------
    # STEP 4: Generate Summary Report
    # -------------------------------------------------------------------------
    print("\n[STEP 4] Generating summary report...")

    summary = []
    summary.append("="*70)
    summary.append("STATISTICAL ANALYSIS SUMMARY")
    summary.append("="*70)
    summary.append("")
    summary.append("PRIMARY ANALYSIS: Robust Paired t-tests (20% trimming)")
    summary.append("-"*70)
    summary.append(f"{'Measure':<25} {'t(df)':<15} {'p (1-tail)':<12} {'d':<8}")
    summary.append("-"*70)

    for result in robust_results:
        summary.append(
            f"{result['measure']:<25} "
            f"{result['t_statistic']:.3f}({result['df']:<2}){'':<6} "
            f"{result['p_one_tailed']:<12.4f} "
            f"{result['cohens_d']:<8.3f}"
        )

    summary.append("")
    summary.append("SUPPLEMENTARY ANALYSIS: Weighted Composite Stress Index")
    summary.append("-"*70)
    summary.append(f"t({len(composite_with)-1}) = {t_stat:.3f}, p (one-tailed) = {p_one:.4f}, d = {cohens_d:.3f}")
    summary.append("")
    summary.append("="*70)

    summary_text = "\n".join(summary)
    print("\n" + summary_text)

    # Save summary (explicit encoding keeps the file identical across platforms)
    summary_file = os.path.join(output_directory, 'analysis_summary.txt')
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write(summary_text)
    print(f"\nSaved summary to: {summary_file}")

    print("\n" + "="*70)
    print("ANALYSIS COMPLETE")
    print("="*70)


# =============================================================================
# SECTION 6: EXECUTION
# =============================================================================

if __name__ == "__main__":
    import argparse  # only needed when the file is run as a script

    # Default paths, used when no command-line arguments are given
    DATA_DIRECTORY = "/home/ubuntu/upload"
    OUTPUT_DIRECTORY = "/home/ubuntu/analysis_results"

    parser = argparse.ArgumentParser(
        description="Run the subtitle support study analysis pipeline."
    )
    parser.add_argument(
        "data_directory", nargs="?", default=DATA_DIRECTORY,
        help="directory containing the raw P[ID]_Block[1/2].txt files "
             "(default: %(default)s)",
    )
    parser.add_argument(
        "output_directory", nargs="?", default=OUTPUT_DIRECTORY,
        help="directory in which result files are written "
             "(default: %(default)s)",
    )
    args = parser.parse_args()

    # Run analysis
    main_analysis(args.data_directory, args.output_directory)
